import torch
import torch.nn as nn
from torch import Tensor
import numpy as np
# NOTE(review): this enables autograd anomaly detection globally as soon as the
# module is imported. PyTorch documents anomaly mode as a debugging aid that
# slows down execution; consider enabling it only while debugging.
torch.autograd.set_detect_anomaly(True)


class ProfileBuilder(nn.Module):
    """Binary event classifier combining a history-based node profile with the event's own features.

    For every event in a batch, a profile embedding is pooled from past events held in
    ``storage`` that share the event's source and/or destination node (``profile_on``).
    ``forward`` feeds that profile, together with embeddings of the event's destination,
    message and timestamp, to a ``Readout`` head and scores the logits against ``batch.y``
    with ``BCEWithLogitsLoss``.

    Args:
        num_src: number of source ids (stored on the instance).
        num_dst: number of destination ids; the destination embedding table has
            ``num_dst + 1`` rows with row 0 as padding.
        msg_dim: width of the ``msg`` feature vectors.
        storage: past events that profiles are built from (required).
        model_dim: embedding width shared by all encoders.
        history_len: keep only the ``history_len`` most recent timestamps of history;
            ``None`` keeps all of it.
        history_version: which event components are pooled into the profile.
        profile_on: ``"src"``, ``"dst"`` or ``"both"``; selects which node(s) a history
            event must share with a batch event to count towards its profile.
        profile_version: pooling mode (``"max"``/``"mean"``; the legacy
            ``summarize_profile`` also accepts ``"relu"``/``"tanh"``).
        time_encoding_version: forwarded to ``TimeEncoder``.
        readout_version: forwarded to ``Readout``.
        device: device for the encoders, the readout head and ``storage``.
        **kwargs: accepted and ignored.

    NOTE(review): ``storage`` and ``batch`` are assumed to be event containers (e.g.
    torch_geometric ``TemporalData``) exposing ``src``/``dst``/``t``/``msg``/``y`` tensors,
    ``len()``, ``.to(device)`` and boolean-mask indexing. Confirm this against the callers.
    """

    # Event components pooled by ``summarize_profile_v4`` for each ``history_version``.
    _V4_HISTORY_PARTS = {
        "dst-msg-t": ("dst", "msg", "t"),
        "dst-msg": ("dst", "msg"),
        "dst": ("dst",),
        "msg": ("msg",),
    }

    def __init__(
            self,
            num_src: int,
            num_dst: int,
            msg_dim: int,
            storage=None,
            model_dim: int = None,
            history_len: int = None,
            history_version: str = None,
            profile_on: str = None,
            profile_version: str = None,
            time_encoding_version: str = None,
            readout_version: str = None,
            device=None,
            **kwargs
    ):
        super().__init__()
        if storage is None:
            # Fail clearly instead of with "'NoneType' object has no attribute 'to'".
            raise ValueError("ProfileBuilder requires `storage` (the past events profiles are built from)")

        self.num_src = num_src
        self.num_dst = num_dst
        self.msg_dim = msg_dim

        self.storage = storage.to(device)

        self.model_dim = model_dim
        self.history_len = history_len
        self.history_version = history_version
        self.profile_on = profile_on
        self.profile_version = profile_version

        # Row 0 is the padding row (padding_idx=0); presumably dst ids start at 1.
        self.dst_encoder = XavierEmbedding((self.num_dst + 1, self.model_dim), padding_idx=0).to(device)
        self.msg_encoder = nn.Linear(in_features=self.msg_dim, out_features=self.model_dim).to(device)
        self.time_encoder = TimeEncoder(
            out_features=self.model_dim, version=time_encoding_version, device=device
        ).to(device)
        # The legacy ``summarize_profile`` reads ``self.hist_lin``, which is not created
        # here; ``forward`` only uses ``summarize_profile_v4``.

        self.readout = Readout(in_features=self.model_dim, out_features=1, version=readout_version).to(device)
        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.device = device
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise whichever encoders are present (``record_encoder`` is optional).

        The readout head is not re-initialised here.
        """
        for name in ("dst_encoder", "msg_encoder", "time_encoder", "record_encoder"):
            encoder = getattr(self, name, None)
            if encoder is not None:
                encoder.reset_parameters()

    def reset_state(self):
        """No-op; this model keeps no state between batches (presumably a training-loop hook)."""
        pass

    def detach(self):
        """No-op; there is no carried-over state to detach (presumably a training-loop hook)."""
        pass

    def forward(self, batch):
        """Score a batch of events.

        Returns:
            ``(loss, y_pred, y_true)``: the BCE-with-logits loss against ``batch.y``, the
            predicted probabilities (``sigmoid(logits)``) and the float labels with a
            trailing singleton dim. The last two are detached CPU tensors.
        """
        batch = batch.to(self.device)

        # History-based profile of each event's src/dst node.
        profile_emb = self.summarize_profile_v4(batch)

        # Embeddings of the current event itself.
        dst_emb = self.dst_encoder(batch.dst)
        msg_emb = self.msg_encoder(batch.msg)
        time_emb = self.time_encoder(batch.t)

        # Readout and loss.
        logits = self.readout(profile_emb, dst_emb, msg_emb, time_emb)
        y = batch.y.float().unsqueeze(-1)
        loss = self.criterion(logits, y)
        y_pred = logits.detach().cpu().sigmoid()
        y_true = y.detach().cpu()

        return loss, y_pred, y_true

    def _recent_history(self, cutoff, inclusive):
        """Return stored events with ``t < cutoff`` (``t <= cutoff`` if ``inclusive``).

        The result is limited to the ``history_len`` most recent timestamps. Events tied
        with the oldest retained timestamp are all kept, so the result may be longer than
        ``history_len``. ``history_len=None`` disables the limit.
        """
        times = self.storage.t
        keep = (times <= cutoff) if inclusive else (times < cutoff)
        history = self.storage[keep]
        if self.history_len is not None and len(history) > self.history_len:
            oldest_kept = torch.topk(history.t, self.history_len)[0].min()
            history = history[history.t >= oldest_kept]
        return history

    def _history_mask(self, batch, history):
        """Return a bool mask ``(batch_size, history_size, 1)`` of matching history events.

        A history event matches a batch event when it shares the node(s) selected by
        ``profile_on``.

        Raises:
            NotImplementedError: unknown ``profile_on``.
        """
        if self.profile_on == "src":
            mask = batch.src.view(-1, 1) == history.src.view(1, -1)
        elif self.profile_on == "dst":
            mask = batch.dst.view(-1, 1) == history.dst.view(1, -1)
        elif self.profile_on == "both":
            mask = (batch.src.view(-1, 1) == history.src.view(1, -1)) | (
                batch.dst.view(-1, 1) == history.dst.view(1, -1))
        else:
            raise NotImplementedError
        # Broadcasting (B, 1) against (1, H) already yields (B, H); no explicit repeat is
        # needed. The trailing dim lets the mask broadcast over the embedding width.
        return mask.unsqueeze(-1)

    def summarize_profile(self, batch):
        """Legacy profile: pool projected (event, label) records of strictly earlier events.

        History is every stored event with ``t < batch.t.min()``, optionally limited by
        ``history_len``. For each history event, the components selected by
        ``history_version`` and its label ``y`` are concatenated and projected by
        ``self.hist_lin``. The projections are then pooled over the history events that
        match each batch event (``profile_on``) according to ``profile_version``:
        ``"max"`` is a masked max; ``"mean"``, ``"relu"`` and ``"tanh"`` use the masked sum
        divided by ``match_count + 1``, optionally followed by ReLU / tanh.

        NOTE(review): ``__init__`` does not create ``self.hist_lin``. With non-empty
        history this therefore raises ``AttributeError`` unless ``hist_lin`` is attached
        externally.

        Raises:
            NotImplementedError: unknown ``history_version``, ``profile_on`` or
                ``profile_version``.
        """
        history = self._recent_history(batch.t.min(), inclusive=False)
        if len(history) == 0:
            return torch.zeros((batch.src.size(0), self.model_dim), device=self.device)

        dst_emb = self.dst_encoder(history.dst)
        # Labels become one extra feature column, cast to the embedding dtype for cat.
        record_emb = history.y.unsqueeze(-1).to(dst_emb.dtype)
        if self.history_version == "dst-msg-t-y":
            parts = (dst_emb, self.msg_encoder(history.msg), self.time_encoder(history.t), record_emb)
        elif self.history_version == "dst-msg-y":
            parts = (dst_emb, self.msg_encoder(history.msg), record_emb)
        elif self.history_version == "dst-y":
            parts = (dst_emb, record_emb)
        else:
            raise NotImplementedError
        history_emb = self.hist_lin(torch.cat(parts, dim=1))

        mask = self._history_mask(batch, history)
        masked = mask * history_emb
        if self.profile_version == "max":
            return masked.max(dim=1)[0]
        if self.profile_version not in ("mean", "relu", "tanh"):
            raise NotImplementedError
        profile_emb = masked.sum(dim=1) / (torch.count_nonzero(mask, dim=1) + 1)
        if self.profile_version == "relu":
            return profile_emb.relu()
        if self.profile_version == "tanh":
            return torch.tanh(profile_emb)
        return profile_emb

    def summarize_profile_v4(self, batch):
        """Pool embeddings of past events into a ``(batch_size, model_dim)`` profile.

        History is every stored event with ``t <= batch.t.max()``, optionally limited by
        ``history_len``. Each component named by ``history_version`` (``dst``, ``msg``,
        ``t``) is encoded per history event and pooled over the history events that match
        each batch event (``profile_on``). The per-component results are then averaged.

        ``profile_version``:
          * ``"max"``: masked max. Unmatched history events contribute zeros, so the
            result is floored at 0 whenever some history event is unmatched.
          * ``"mean"``: masked sum divided by ``match_count + 1``. This is a smoothed mean
            that yields 0 (not NaN) for events without matching history.

        If there is no history at all, a zero profile is returned.

        NOTE(review): the cutoff is the batch's *latest* timestamp. An event's profile may
        therefore include stored events later than the event itself, within the batch's
        time span. If ``storage`` already holds this batch, that is look-ahead; confirm
        this is intended.

        Raises:
            NotImplementedError: unknown ``profile_on``, ``history_version`` or
                ``profile_version``.
        """
        history = self._recent_history(batch.t.max(), inclusive=True)
        if len(history) == 0:
            return torch.zeros((batch.src.size(0), self.model_dim), device=self.device)

        mask = self._history_mask(batch, history)  # (batch_size, history_size, 1)

        parts = self._V4_HISTORY_PARTS.get(self.history_version)
        if parts is None:
            raise NotImplementedError
        encoders = {
            "dst": lambda: self.dst_encoder(history.dst),
            "msg": lambda: self.msg_encoder(history.msg),
            "t": lambda: self.time_encoder(history.t),
        }
        # A leading singleton dim lets each (history_size, model_dim) encoding broadcast
        # against the mask instead of materialising one copy per batch row.
        components = [encoders[part]().unsqueeze(0) for part in parts]

        if self.profile_version == "max":
            pooled = [(mask * emb).max(dim=1)[0] for emb in components]
            denominator = len(pooled)
        elif self.profile_version == "mean":
            pooled = [(mask * emb).sum(dim=1) for emb in components]
            denominator = len(pooled) * (torch.count_nonzero(mask, dim=1) + 1)
        else:
            raise NotImplementedError

        # Sum left to right, then average over the components.
        total = pooled[0]
        for extra in pooled[1:]:
            total = total + extra
        return total / denominator


class Readout(nn.Module):
    """Linear scoring head over a profile embedding plus optional event embeddings.

    ``version`` chooses which inputs are used. Each chosen input passes through its own
    ``in_features -> in_features`` linear layer, the results are summed, and
    ``lin_final`` maps the sum to ``out_features`` logits.

    Raises:
        NotImplementedError: unsupported ``version``.
    """

    # (version, optional inputs besides the always-used profile), in layer-creation order.
    _VERSIONS = (
        ("profile", ()),
        ("profile-dst", ("dst",)),
        ("profile-dst-msg", ("dst", "msg")),
        ("profile-dst-msg-t", ("dst", "msg", "t")),
    )

    def __init__(self, in_features: int, out_features: int = 1, version=None):
        super().__init__()
        self.version = version
        for known_version, extra_inputs in self._VERSIONS:
            if version == known_version:
                break
        else:
            raise NotImplementedError

        self.lin_profile = nn.Linear(in_features, in_features)
        for input_name in extra_inputs:
            setattr(self, f"lin_{input_name}", nn.Linear(in_features, in_features))
        self.lin_final = nn.Linear(in_features, out_features)

    def forward(self, z_profile, z_dst, z_msg, z_t):
        """Return logits from the profile and whichever optional inputs this version uses."""
        hidden = self.lin_profile(z_profile)
        optional = (("lin_dst", z_dst), ("lin_msg", z_msg), ("lin_t", z_t))
        for layer_name, z in optional:
            if hasattr(self, layer_name):
                hidden += getattr(self, layer_name)(z)
        return self.lin_final(hidden)


class TimeEncoder(nn.Module):
    """Fixed (non-trainable) cosine encoding of scalar timestamps.

    A timestamp ``t`` maps to ``cos(t * w_i)`` for ``out_features`` frequencies
    ``w_i = 1 / 10 ** x_i``, with ``x_i`` evenly spaced over ``[0, 9]``. The ``version``
    argument is accepted but not used.
    """

    def __init__(self, out_features: int, device, version="fix"):
        super().__init__()
        self.device = device
        self.out_features = out_features
        self.lin = nn.Linear(1, out_features)
        self.reset_parameters()

    # TODO: device management here is not perfect but works for the moment

    def reset_parameters(self) -> None:
        """Rebuild the frozen frequency weights and zero bias on ``self.device``."""
        exponents = np.linspace(0, 9, self.out_features, dtype=np.float32)
        frequencies = torch.from_numpy(1 / 10 ** exponents).to(self.device)
        # Frozen parameters: the encoding is fixed rather than learned.
        self.lin.weight = nn.Parameter(frequencies.reshape(self.out_features, -1), requires_grad=False)
        self.lin.bias = nn.Parameter(torch.zeros(self.out_features, device=self.device), requires_grad=False)

    @torch.no_grad()
    def forward(self, t: Tensor) -> Tensor:
        """Encode ``t`` (flattened to a column) into shape ``(t.numel(), out_features)``."""
        column = t.float().reshape((-1, 1))
        return torch.cos(self.lin(column))


class XavierEmbedding(nn.Embedding):
    """``nn.Embedding`` whose table is Xavier-normal initialised, with the padding row kept at zero."""

    def __init__(self, shape: tuple, padding_idx=None):
        # ``shape`` is splatted positionally into nn.Embedding, e.g. (num_embeddings, embedding_dim).
        embedding_args = tuple(shape)
        super().__init__(*embedding_args, padding_idx=padding_idx)

    def reset_parameters(self) -> None:
        """Xavier-normal initialise the table, then re-zero the padding row (if any)."""
        table = self.weight
        nn.init.xavier_normal_(table)
        self._fill_padding_idx_with_zero()


class XavierLinear(nn.Linear):
    """``nn.Linear`` with Xavier-normal weights and a zero-initialised bias."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias=bias)

    def reset_parameters(self) -> None:
        """Initialise the weight with Xavier normal and the bias (if any) with zeros.

        ``nn.Linear`` allocates its parameters uninitialised and relies on
        ``reset_parameters`` to fill them. The bias must therefore be set here too,
        otherwise it keeps whatever memory it was allocated with.
        """
        nn.init.xavier_normal_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
